# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

from dataclasses import dataclass
from enum import Enum
from typing import Any, Optional, Union

import torch

from torchao.core.config import AOBaseConfig
from torchao.prototype.mx_formats.constants import (
    DTYPE_TO_SHORT_STR,
    SUPPORTED_ELEM_DTYPES,
)
from torchao.quantization.quantize_.common.kernel_preference import KernelPreference


class MXFP8Dim1CastKernelChoice(Enum):
    """
    Selects the kernel used for mxfp8 casting.

    The custom casting kernels currently only cover scaling along dim1;
    scaling along dim0 always goes through native torch code.
    """

    # Triton casting kernel
    TRITON = "triton"
    # CUDA casting kernel
    CUDA = "cuda"
    # torch native code
    TORCH = "torch"


# Pre-made recipes for common configurations
class MXLinearRecipeName(Enum):
    """
    Names of the pre-made MXLinearConfig recipes.

    Either a member or its string value can be handed to
    ``MXLinearConfig.from_recipe_name`` to build the matching config.
    """

    MXFP8_EMULATED = "mxfp8_emulated"
    MXFP8_CUBLAS = "mxfp8_cublas"
    MXFP8_CUBLAS_RCEIL = "mxfp8_cublas_rceil"
    MXFP4_EMULATED = "mxfp4_emulated"
    MXFP4_CUTLASS = "mxfp4_cutlass"


class ScaleCalculationMode(Enum):
    """
    Enum representing the different methods for calculating MX block scaling.
    There are four methods available:

    FLOOR: This method is recommended by the OCP MX Spec 1.0 and uses X = 2^floor(log2(max_abs(v))-max_exp).
           It can overflow for large values and is a poor fit for gradient quantization.

    RCEIL: This method rounds the ratio max_abs(v) / max_pos up to a power of two,
           i.e. X = 2^ceil(log2(max_abs(v) / max_pos)).
           Details are described in https://docs.nvidia.com/cuda/cublas/index.html#d-block-quantization,
           section "Computing scaling and conversion factors for FP8 with UE8M0 scales".

    CEIL: This method avoids overflow issues, but small values may shift to 0 due to a large scaling factor.
           It uses X = 2^ceil(log2(max_abs(v))-max_exp).

    EVEN: This method is a trade-off between FLOOR and CEIL. It uses X = 2^(floor(log2(rounding(max_abs(v)))-max_exp)).
           It provides better accuracy for MX4 training compared to FLOOR and CEIL.
           Note: EVEN does not work with torch.compile yet:
           https://gist.github.com/vkuzo/1a04845cd503b1c75291aa1ea3bf79c4

    """

    FLOOR = "floor"
    RCEIL = "rceil"
    CEIL = "ceil"
    EVEN = "even"


def _validate_elem_dtype(elem_dtype):
    """
    Check that ``elem_dtype`` is one of ``SUPPORTED_ELEM_DTYPES``.

    Raises:
        AssertionError: if ``elem_dtype`` is not supported. The error is raised
            explicitly rather than via ``assert`` so the check still runs under
            ``python -O``; the exception type is unchanged for existing callers.
    """
    if elem_dtype not in SUPPORTED_ELEM_DTYPES:
        raise AssertionError(
            f"elem_dtype: expected one of {SUPPORTED_ELEM_DTYPES}, got {elem_dtype}"
        )


def _validate_kernel_preference(kernel_preference, block_size, elem_dtype):
    """
    Check that ``kernel_preference`` is compatible with ``block_size`` and ``elem_dtype``.

    ``KernelPreference.AUTO`` is accepted only when ``elem_dtype`` is
    ``torch.float8_e4m3fn`` or ``torch.float4_e2m1fn_x2`` and ``block_size`` is 32.
    Any other preference must be ``KernelPreference.EMULATED``.

    Raises:
        AssertionError: if the combination is unsupported. Raised explicitly
            (not via ``assert``) so validation still happens under ``python -O``.
    """
    if kernel_preference == KernelPreference.AUTO:
        if elem_dtype not in (torch.float8_e4m3fn, torch.float4_e2m1fn_x2):
            raise AssertionError(
                f"unsupported {kernel_preference=}, {block_size=}, {elem_dtype=}"
            )
        if block_size != 32:
            raise AssertionError(f"block_size must be 32, got {block_size}")
    elif kernel_preference != KernelPreference.EMULATED:
        raise AssertionError(
            f"unsupported {kernel_preference=}, {block_size=}, {elem_dtype=}"
        )


def _validate_mxfp8_cast_kernel_choice(
    mxfp8_cast_kernel_choice, scale_calculation_mode
):
    """
    Check that the dim1 mxfp8 cast kernel supports ``scale_calculation_mode``.

    - TRITON only supports ``ScaleCalculationMode.FLOOR``.
    - CUDA supports ``ScaleCalculationMode.FLOOR`` and ``ScaleCalculationMode.RCEIL``.
    - Any other choice (e.g. TORCH) is not restricted here.

    Raises:
        AssertionError: if the mode is unsupported for the chosen kernel. Raised
            explicitly (not via ``assert``) so the check survives ``python -O``.
    """
    if mxfp8_cast_kernel_choice == MXFP8Dim1CastKernelChoice.TRITON:
        if scale_calculation_mode != ScaleCalculationMode.FLOOR:
            raise AssertionError(
                f"unsupported ScaleCalculationMode value {scale_calculation_mode} for dim1 triton cast"
            )
    elif mxfp8_cast_kernel_choice == MXFP8Dim1CastKernelChoice.CUDA:
        if scale_calculation_mode not in (
            ScaleCalculationMode.FLOOR,
            ScaleCalculationMode.RCEIL,
        ):
            raise AssertionError(
                f"unsupported ScaleCalculationMode value {scale_calculation_mode} for dim1 cuda cast"
            )


@dataclass
class MXLinearConfig(AOBaseConfig):
    """
    Configuration for MX (microscaling) linear layers.

    All fields are validated in ``__post_init__``; unsupported combinations
    raise ``AssertionError``. Use ``from_recipe_name`` to build one of the
    pre-made recipes listed in ``MXLinearRecipeName``.
    """

    # block size for scaling, default is 32 to match
    # https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf,
    # section 5.2
    block_size: int = 32

    # element dtype, used for activations, weights and gradients
    elem_dtype: Any = torch.float8_e4m3fn

    # overrides for element dtype for weights and gradients; only supported
    # together with KernelPreference.EMULATED (checked in __post_init__)
    # TODO(future PR): refactor to make this cleaner
    elem_dtype_weight_override: Optional[Any] = None
    elem_dtype_grad_output_override: Optional[Any] = None

    # defines the kernel preference, if the chosen kernel is not supported
    # on the given hardware an exception will be thrown
    kernel_preference: KernelPreference = KernelPreference.EMULATED

    # define which kernel to use for mxfp8 casting
    # TODO(1945): remove this config option once torch.compile gives us
    # a fast kernel
    mxfp8_cast_kernel_choice: MXFP8Dim1CastKernelChoice = (
        MXFP8Dim1CastKernelChoice.TORCH
    )

    # method used to calculate the MX block scale
    scale_calculation_mode: ScaleCalculationMode = ScaleCalculationMode.FLOOR

    def __post_init__(self):
        """
        Validate the field combination.

        Raises:
            AssertionError: on an unsupported dtype, kernel preference, dtype
                override, or cast-kernel / scale-mode combination. Raised
                explicitly rather than via ``assert`` so it survives ``python -O``.
        """
        _validate_elem_dtype(self.elem_dtype)
        _validate_kernel_preference(
            self.kernel_preference, self.block_size, self.elem_dtype
        )
        # Dtype overrides are only supported with the emulated kernel preference.
        for field_name in (
            "elem_dtype_weight_override",
            "elem_dtype_grad_output_override",
        ):
            override = getattr(self, field_name)
            if override is None:
                continue
            _validate_elem_dtype(override)
            if self.kernel_preference != KernelPreference.EMULATED:
                raise AssertionError(
                    f"unsupported: {field_name} requires KernelPreference.EMULATED, "
                    f"got {self.kernel_preference}"
                )
        _validate_mxfp8_cast_kernel_choice(
            self.mxfp8_cast_kernel_choice, self.scale_calculation_mode
        )

    @staticmethod
    def from_recipe_name(
        recipe_name: Union[MXLinearRecipeName, str],
    ) -> "MXLinearConfig":
        """
        Input: `MXLinearRecipeName` value, or a string representing a `MXLinearRecipeName` value
        Output: a `MXLinearConfig` configured to implement the specified recipe

        Raises:
            AssertionError: if ``recipe_name`` is a string that is not a valid
                ``MXLinearRecipeName`` value, or is not a known recipe.
        """
        if isinstance(recipe_name, str):
            valid_names = [n.value for n in MXLinearRecipeName]
            # Checked explicitly so a bad name gives a readable AssertionError
            # (even under `python -O`) instead of the Enum's ValueError.
            if recipe_name not in valid_names:
                raise AssertionError(
                    f"recipe_name {recipe_name} not in valid names {valid_names}"
                )
            recipe_name = MXLinearRecipeName(recipe_name)

        if recipe_name is MXLinearRecipeName.MXFP8_EMULATED:
            return MXLinearConfig()
        elif recipe_name is MXLinearRecipeName.MXFP8_CUBLAS:
            return MXLinearConfig(
                kernel_preference=KernelPreference.AUTO,
                mxfp8_cast_kernel_choice=MXFP8Dim1CastKernelChoice.CUDA,
            )
        elif recipe_name is MXLinearRecipeName.MXFP8_CUBLAS_RCEIL:
            return MXLinearConfig(
                kernel_preference=KernelPreference.AUTO,
                mxfp8_cast_kernel_choice=MXFP8Dim1CastKernelChoice.CUDA,
                scale_calculation_mode=ScaleCalculationMode.RCEIL,
            )
        elif recipe_name is MXLinearRecipeName.MXFP4_EMULATED:
            return MXLinearConfig(elem_dtype=torch.float4_e2m1fn_x2)
        elif recipe_name is MXLinearRecipeName.MXFP4_CUTLASS:
            return MXLinearConfig(
                elem_dtype=torch.float4_e2m1fn_x2,
                kernel_preference=KernelPreference.AUTO,
            )
        else:
            raise AssertionError(f"unknown recipe_name {recipe_name}")

    def short_str(self) -> str:
        """
        Return a concise, comma-separated summary of this config.

        Dtypes are rendered via ``DTYPE_TO_SHORT_STR`` and enum fields via their
        ``.value``. Dtype overrides are included only when set, and
        ``scale_calculation_mode`` only when it is not ``ScaleCalculationMode.FLOOR``.
        """
        parts = [
            f"bl_sz={self.block_size}",
            f"lp_dtype={DTYPE_TO_SHORT_STR[self.elem_dtype]}",
        ]
        if self.elem_dtype_weight_override is not None:
            parts.append(
                f"lp_w_override={DTYPE_TO_SHORT_STR[self.elem_dtype_weight_override]}"
            )
        if self.elem_dtype_grad_output_override is not None:
            parts.append(
                f"lp_go_override={DTYPE_TO_SHORT_STR[self.elem_dtype_grad_output_override]}"
            )
        parts.append(f"kernel={self.kernel_preference.value}")
        parts.append(
            f"mxfp8_cast_kernel_choice={self.mxfp8_cast_kernel_choice.value}"
        )
        if self.scale_calculation_mode != ScaleCalculationMode.FLOOR:
            # Use .value to match how the other enum fields are rendered.
            parts.append(
                f"scale_calculation_mode={self.scale_calculation_mode.value}"
            )
        return ", ".join(parts)
